import torch
import numpy as np
import argparse
from utils.load_dataset import *
from utils.instantiate_model import *
from utils.str2bool import str2bool

# Command-line interface. The defaults are shown in --help through
# ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(description='Train', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--parallel',               default=False,          type=str2bool,  help='Device in  parallel')

# Dataloader args
parser.add_argument('--train_batch_size',       default=512,            type=int,       help='Train batch size')
parser.add_argument('--test_batch_size',        default=512,            type=int,       help='Test batch size')
parser.add_argument('--val_split',              default=0.1,            type=float,     help='Fraction of training dataset split as validation')
parser.add_argument('--augment',                default=True,           type=str2bool,  help='Random horizontal flip and random crop')
parser.add_argument('--padding_crop',           default=4,              type=int,       help='Padding for random crop')
parser.add_argument('--shuffle',                default=True,           type=str2bool,  help='Shuffle the training dataset')
parser.add_argument('--random_seed',            default=0,              type=int,       help='Initialising the seed for reproducibility')
parser.add_argument('--arch',                   default='resnet18',     type=str,       help='Network architecture')
parser.add_argument('--suffixs',                default='1,2,3,4,5',    type=str,       help='Model suffixs')
parser.add_argument('--outfile',                default='sep_mean.txt',      type=str,       help='Name of the output file to store results from comparison with other OoD methods')

# `args` is a module-level name read by the functions below. A `global`
# statement is not needed at module scope, so none is used here.
args = parser.parse_args()
print(args)

# Results table; written by get_metric_for_dataset and closed at the end of the script.
out_file = open(args.outfile, "w")

# Setup right device to run on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# In-distribution datasets to evaluate, and the OoD datasets each one is compared against.
in_datasets = ['cifar10', 'cifar100', 'svhn', 'tinyimagenet']
ood_datasets = ['g-noise', 'noise', 'svhn', 'cifar100', 'textures', 'lsun', 'tinyimagenet', 'places365']

print('\n')


def get_test_data(ood_dataset, dataset, mode, suffix, net=None, examples_per_class=1000):
    """Return latent-space features for ``ood_dataset``.

    Three sources are supported:

    * ``'noise'`` / ``'g-noise'``: uniform(0, 1) / normal(0, 1) random images
      are generated and passed through ``net`` to obtain features on the fly.
    * ``ood_dataset == dataset.name``: the stored in-distribution features of
      the model identified by ``mode`` and ``suffix`` are loaded.
    * anything else: the stored features of ``ood_dataset`` evaluated on the
      model trained on ``dataset`` are loaded.

    Args:
        ood_dataset: Name of the dataset whose features are requested.
        dataset: In-distribution dataset object; ``name``, ``img_ch`` and
            ``img_dim`` are read from it.
        mode: Training algorithm tag used in the stored file name.
        suffix: Model suffix used in the stored file name.
        net: Model used for the noise datasets; it is called as
            ``net(x, latent=True)`` and its second output is used.
        examples_per_class: Number of random images generated for the noise
            datasets.

    Returns:
        The features as produced by the model (with a leading singleton
        dimension added) for the noise datasets, otherwise whatever object
        ``torch.load`` returns for the stored file.

    Raises:
        ValueError: If a noise dataset is requested without a ``net``.
    """
    if ood_dataset in ('noise', 'g-noise'):
        if net is None:
            raise ValueError("A model is required to embed '{}' samples".format(ood_dataset))

        shape = (examples_per_class, dataset.img_ch, dataset.img_dim, dataset.img_dim)
        if ood_dataset == 'noise':
            samples = np.random.uniform(0, 1, shape)
        else:
            samples = np.random.normal(0, 1, shape)

        data_rand = torch.Tensor(samples).to(device)
        with torch.no_grad():
            _, fet = net(data_rand, latent=True)

        return fet.unsqueeze(0)

    if ood_dataset == dataset.name:
        # The original code read `args.dataset` and `algo`, neither of which
        # exists in this scope; the dataset object and `mode` carry the same
        # information. The name is lower-cased like the OoD path below.
        path = './outputs/latent_space/{}_{}_{}.vec'.format(dataset.name.lower(),
                                                            args.arch,
                                                            mode + '_' + str(suffix))
    else:
        path = './outputs/latent_space/{}_on_{}_{}_{}.vec'.format(ood_dataset.lower(),
                                                                  dataset.name.lower(),
                                                                  args.arch,
                                                                  mode + '_' + str(suffix))

    return torch.load(path)

def _min_angle_stats(features, unit_weights):
    """Return (mean, std) of each sample's angle to its closest weight vector."""
    unit_features = features / np.linalg.norm(features, axis=1, keepdims=True)
    max_cos = np.max(np.matmul(unit_features, unit_weights.T), axis=1)
    # Rounding can push a cosine marginally outside [-1, 1], where arccos
    # returns NaN and would poison the mean/std of the whole dataset.
    angles = np.arccos(np.clip(max_cos, -1.0, 1.0))
    return np.mean(angles), np.std(angles)


def get_metric(in_data, ood_dataset, mode, suffix):
    """Compute separation statistics between in-distribution and OoD features.

    The model identified by ``mode`` and ``suffix`` is loaded for the
    module-level ``dataset``; its final linear layer provides the class weight
    vectors against which angles are measured.

    Args:
        in_data: Tensor of in-distribution latent features. ``shape[1]`` is
            used as the number of noise samples to generate, and the tensor is
            flattened to rows of feature vectors.
        ood_dataset: Name of the OoD dataset to compare against.
        mode: Training algorithm tag of the model.
        suffix: Model suffix as a string (it is concatenated to ``mode``).

    Returns:
        Tuple ``(mean_dis, in_var, ood_var, in_mean_angle, in_std_angle,
        ood_mean_angle, ood_std_angle)`` where

        * ``mean_dis`` is the Euclidean distance between the mean in-dist and
          mean OoD feature vectors (0-d tensor),
        * ``in_var`` / ``ood_var`` are the standard deviation over all feature
          values (0-d tensors; not squared, despite the names),
        * the angle entries are the mean / std, in radians, of the angle
          between each L2-normalised feature and its closest L2-normalised
          class weight vector.
    """
    # Instantiate model
    net, _ = instantiate_model(dataset=dataset,
                               arch=args.arch,
                               suffix=mode + "_" + suffix,
                               load=True,
                               torch_weights=False,
                               device=device,
                               verbose=False)
    net.eval()

    ood_data = get_test_data(ood_dataset,
                             dataset,
                             mode=mode,
                             suffix=suffix,
                             net=net,
                             examples_per_class=in_data.shape[1])

    w = net.linear.weight.cpu().detach().numpy()
    # The classifier's input width is the feature size (512 for the default
    # architecture); taking it from the weights avoids hard-coding it.
    feature_dim = w.shape[1]

    ood_data = ood_data.reshape(-1, feature_dim).cpu().detach()
    in_data = in_data.reshape(-1, feature_dim).cpu().detach()

    in_mean = in_data.mean(0)
    out_mean = ood_data.mean(0)

    in_var = in_data.std()
    ood_var = ood_data.std()

    # Euclidean distance between the two mean vectors.
    diff = in_mean - out_mean
    mean_dis = torch.sqrt(torch.dot(diff, diff))

    # Only the direction of each class weight vector matters for the angle;
    # the bias is deliberately not used.
    w = w / np.linalg.norm(w, axis=1, keepdims=True)

    in_mean_angle, in_std_angle = _min_angle_stats(in_data.numpy(), w)
    ood_mean_angle, ood_std_angle = _min_angle_stats(ood_data.numpy(), w)

    return mean_dis, in_var, ood_var, in_mean_angle, in_std_angle, ood_mean_angle, ood_std_angle


def get_metric_for_dataset(in_dataset):
    """Write the ERM vs. VRM separation table for one in-distribution dataset.

    For every OoD dataset (skipping the in-distribution dataset itself and any
    CIFAR/CIFAR pairing), the metrics from ``get_metric`` are averaged over
    all model suffixes for both algorithms and written as one row to
    ``out_file``, followed by an "Average" row over all OoD datasets.

    Args:
        in_dataset: In-distribution dataset object; its ``name`` is read.

    Returns:
        Tuple ``(vrm, erm)`` of per-suffix metric arrays for the last OoD
        dataset that was evaluated, one row per suffix with columns
        ``[mean_dis, ood_var, in_var, in_mean_angle, in_std_angle,
        ood_mean_angle, ood_std_angle]``. Both are empty arrays when no OoD
        dataset was evaluated.
    """
    suffixs = args.suffixs.split(',')
    in_name = in_dataset.name.lower()

    print('------------------------------------------------------------------------------', file=out_file)
    print('\t' + in_dataset.name, file=out_file)
    print('------------------------------------------------------------------------------', file=out_file)
    description_list = ['OoD Dataset', 'Distance', "OoD Var\t", "In Var\t", "In Angle", "OoD Angle", "Var Ratio", "Dis Ratio", "X/Y"]
    print('\t\t|\t\t'.join(description_list), file=out_file)
    ref_str = "{: <12}\t|" + '\t{:.4f}\t\t{:.4f}\t|' * 5 + "\t\t{:.4f}\t\t\t|" * 3

    def write_row(label, erm_vals, vrm_vals, ood_ratio, dis_ratio, sep_index):
        # Columns are ERM/VRM pairs for distance, OoD spread, in-dist spread,
        # in-dist mean angle and OoD mean angle, followed by the three ratios.
        print(ref_str.format(label,
                             erm_vals[0], vrm_vals[0],
                             erm_vals[1], vrm_vals[1],
                             erm_vals[2], vrm_vals[2],
                             erm_vals[3], vrm_vals[3],
                             erm_vals[5], vrm_vals[5],
                             ood_ratio,
                             dis_ratio,
                             sep_index),
              file=out_file)

    vrm_rows = []
    erm_rows = []
    vrm = np.array([])
    erm = np.array([])
    for ood_dataset in ood_datasets:
        ood_name = ood_dataset.lower()
        if ood_name == in_name:
            continue

        # CIFAR-10 and CIFAR-100 are not compared against each other.
        if 'cifar' in in_name and 'cifar' in ood_name:
            continue

        runs = {'vrm': [], 'erm': []}
        for algo in ['vrm', 'erm']:
            for suffix in suffixs:
                path = './outputs/latent_space/{}_{}_{}.vec'.format(in_dataset.name, args.arch, algo + '_' + str(suffix))
                in_data = torch.load(path)
                (mean_dis, in_var, ood_var,
                 in_mean_angle, in_std_angle,
                 ood_mean_angle, ood_std_angle) = get_metric(in_data,
                                                             ood_dataset,
                                                             mode=algo,
                                                             suffix=str(suffix))

                # get_metric returns a mix of 0-d tensors and NumPy scalars;
                # plain floats give a well-defined numeric array below.
                runs[algo].append([float(value) for value in (mean_dis, ood_var, in_var,
                                                              in_mean_angle, in_std_angle,
                                                              ood_mean_angle, ood_std_angle)])

        vrm = np.array(runs['vrm'])
        erm = np.array(runs['erm'])
        vrm_values = vrm.mean(axis=0)
        erm_values = erm.mean(axis=0)
        dis_ratio = vrm_values[0] / erm_values[0]
        ood_ratio = vrm_values[1] / erm_values[1]
        sep_index = dis_ratio / ood_ratio

        # The ratio columns are stored in the same order for every row.
        vrm_rows.append(np.append(vrm_values, [dis_ratio, ood_ratio, sep_index]))
        erm_rows.append(erm_values)

        write_row(ood_dataset, erm_values, vrm_values, ood_ratio, dis_ratio, sep_index)

    if not vrm_rows:
        # Every OoD dataset was skipped, so there is nothing to average.
        print("", file=out_file)
        return vrm, erm

    sum_vrm = np.mean(vrm_rows, axis=0)
    sum_erm = np.mean(erm_rows, axis=0)
    # The summary ratios are ratios of the averaged metrics, not averages of
    # the per-dataset ratios.
    dis_ratio = sum_vrm[0] / sum_erm[0]
    ood_ratio = sum_vrm[1] / sum_erm[1]
    sep_index = dis_ratio / ood_ratio
    write_row("Average", sum_erm, sum_vrm, ood_ratio, dis_ratio, sep_index)
    print("", file=out_file)

    return vrm, erm


# Evaluate every in-distribution dataset in turn. `dataset` must remain a
# module-level name because get_metric reads it as a global.
# The try/finally guarantees the results file is flushed and closed even if
# one of the datasets fails part-way through.
try:
    for in_dataset in in_datasets:
        dataset = load_dataset(dataset=in_dataset,
                               train_batch_size=args.train_batch_size,
                               test_batch_size=args.test_batch_size,
                               val_split=args.val_split,
                               augment=args.augment,
                               padding_crop=args.padding_crop,
                               shuffle=args.shuffle,
                               random_seed=args.random_seed,
                               device=device)

        vrm, erm = get_metric_for_dataset(dataset)
finally:
    out_file.close()
